import torch
from torch import nn

class PixelNorm(nn.Module):
    """Scale a tensor by the inverse of its root-mean-square along ``dim``.

    Computes ``x / (sqrt(mean(x ** 2, dim, keepdim=True)) + eps)``. ``eps`` is
    added to the RMS itself (outside the square root), so an all-zero vector
    maps to zeros instead of dividing by zero.

    Args:
        dim: Dimension (anything ``torch.mean`` accepts as ``dim``) over which
            the mean square is taken. Defaults to 1 -- presumably the channel
            axis of a batch-first tensor; confirm against callers.
        eps: Constant added to the RMS before dividing.
    """

    def __init__(self, dim=1, eps=1e-4):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` divided by (its RMS along ``self.dim`` + ``self.eps``).

        The result has the same shape as ``x``.
        """
        mean_sq = torch.mean(x ** 2, dim=self.dim, keepdim=True)

        # sqrt has an unbounded derivative at 0: its backward computes
        # grad / (2 * sqrt(v)), which is 0 / 0 = NaN whenever a whole vector
        # along `dim` is zero (or its squares underflow to zero). Route those
        # entries around sqrt so the forward value stays 0 and the gradient
        # that flows back through the RMS is exactly 0 instead of NaN.
        # NaN inputs compare unequal to 0, so they still propagate as before.
        is_zero = mean_sq == 0
        rms = torch.sqrt(mean_sq.masked_fill(is_zero, 1.0)).masked_fill(is_zero, 0.0)

        return x / (rms + self.eps)